from loadLabels import LoadLabels as lls

import numpy as np
import cv2
import math
import random

def compute_iou(rect1, rect2):
    """
    Return the intersection-over-union of two axis-aligned boxes.

    :param rect1: first box as (x, y, w, h), where (x, y) is the top-left corner
    :param rect2: second box, same (x, y, w, h) layout
    :return: IoU as a scalar; the int 0 when the boxes do not overlap
             (boxes that only touch along an edge count as non-overlapping)
    """
    left1, top1 = rect1[0], rect1[1]
    right1, bottom1 = rect1[0] + rect1[2], rect1[1] + rect1[3]
    left2, top2 = rect2[0], rect2[1]
    right2, bottom2 = rect2[0] + rect2[2], rect2[1] + rect2[3]

    # Edges of the candidate intersection rectangle.
    inter_left, inter_right = max(left1, left2), min(right1, right2)
    inter_top, inter_bottom = max(top1, top2), min(bottom1, bottom2)
    if inter_left >= inter_right or inter_top >= inter_bottom:
        return 0

    area1 = (bottom1 - top1) * (right1 - left1)
    area2 = (bottom2 - top2) * (right2 - left2)
    overlap = (inter_right - inter_left) * (inter_bottom - inter_top)
    return overlap / (area1 + area2 - overlap)

def isBoxinBox(box, start_x, start_y, end_x, end_y):
    """
    Tell whether an (x, y, w, h) box lies inside the window [start, end).

    The top-left corner may sit on the window's start edge, but the box's
    right/bottom edge (x + w, y + h) must be strictly less than end_x/end_y.
    """
    return (start_x <= box[0]
            and start_y <= box[1]
            and box[2] + box[0] < end_x
            and box[3] + box[1] < end_y)

def isBoxOutofRange(box, w_low, h_low, w_high, h_high):
    """
    Tell whether an (x, y, w, h) box's size falls outside the allowed range.

    Width must lie in [w_low, w_high] and height in [h_low, h_high] (both
    inclusive); anything else is reported as out of range.
    """
    width, height = box[2], box[3]
    too_small = width < w_low or height < h_low
    return too_small or width > w_high or height > h_high

def labelFunV1(label, start_x, start_y, end_x, end_y,
               w_low=10, h_low=10, w_high=400, h_high=400):
    """
    Keep the boxes of ``label`` that fall inside one crop, in crop coordinates.

    Intended to be passed as ``labelFun`` to cutImage().

    :param label: iterable of (x, y, w, h) boxes in full-image coordinates
    :param start_x, start_y: top-left corner of the crop
    :param end_x, end_y: bottom-right limit of the crop; a box is kept only if
        x + w < end_x and y + h < end_y (see isBoxinBox)
    :param w_low, h_low, w_high, h_high: inclusive size range; boxes whose
        width or height falls outside it are dropped (defaults reproduce the
        previously hard-coded 10..400 range)
    :return: list of numpy arrays (x, y, w, h) relative to the crop origin
    """
    offset = np.array([start_x, start_y, 0, 0])  # shift position, keep size
    patches = []
    for box in label:
        if not isBoxinBox(box, start_x, start_y, end_x, end_y):
            continue
        relative_box = np.array(box) - offset
        if isBoxOutofRange(relative_box, w_low, h_low, w_high, h_high):
            continue
        patches.append(relative_box)
    return patches

def addNegtiveYOLT(boxes, p, max_coord=415):
    """
    Extend ``boxes`` in place with randomly sampled negative (background) boxes.

    Each original box is picked with probability min(0.5, 10.1 / len(boxes)).
    For a picked box (x, y, w, h), the eight same-sized neighbours surrounding
    it are proposed. Each proposal is clamped so that its corner and its extent
    stay within [0, max_coord]. A proposal is dropped in two cases:

    - the clamping leaves it with zero width or height;
    - its IoU with any box currently in ``boxes`` (originals plus negatives
      already added) exceeds 0.01.

    Each surviving proposal is appended with probability ``p``.

    :param boxes: list of (x, y, w, h) boxes in crop coordinates; extended in place
    :param p: probability of keeping each valid proposal
    :param max_coord: largest allowed coordinate (415 for a 416 x 416 crop)
    :return: None
    """
    if not boxes:
        return

    supposed_num = 10
    # With many boxes, about supposed_num + 0.1 of them get picked on average.
    pick_prob = min(0.5, (supposed_num + 0.1) / len(boxes))

    # Iterate a snapshot of the originals: appending to ``boxes`` while
    # iterating it would let freshly added negatives spawn more negatives.
    for box in list(boxes):
        if random.random() > pick_prob:
            continue

        x, y, w, h = box[0], box[1], box[2], box[3]
        candidates = [
            (x - w, y - h), (x - w, y), (x - w, y + h), (x, y + h),
            (x + w, y + h), (x + w, y), (x + w, y - h), (x, y - h),
        ]
        for cand_x, cand_y in candidates:
            new_x = min(max_coord, max(0, cand_x))
            new_y = min(max_coord, max(0, cand_y))
            new_w = min(max_coord - new_x, min(max_coord, max(0, w)))
            new_h = min(max_coord - new_y, min(max_coord, max(0, h)))
            if new_w <= 0 or new_h <= 0:
                # Clamped against the border into a degenerate box; such a box
                # has no area, would never "overlap" anything, and is useless.
                continue
            new_box = [new_x, new_y, new_w, new_h]
            if any(compute_iou(new_box, other) > 0.01 for other in boxes):
                continue
            if random.random() < p:
                boxes.append(new_box)

def converLabelYOLT(label, is_addN=True, cells=26, boxes=2, length=16, new_boxes=None):
    '''
    Description
    -----------
    Convert per-patch box labels into YOLT-style target grids.

    Parameters
    ----------
    - label: list with one entry per cropped patch; each entry is a list of
      (x, y, w, h) boxes in patch coordinates (as returned by labelFunV1())
    - is_addN: if True, each patch's box list is extended IN PLACE with sampled
      negative boxes via addNegtiveYOLT(img, 0.2)
    - cells: number of grid cells per row (and per column)
    - boxes: number of prediction slots per cell
    - length: side length of one cell in pixels
    - new_boxes: optional list; when is_addN is True and this is not None,
      each (augmented) per-patch box list is appended to it

    Returns
    -------
    A list of float arrays of shape (cells, cells, boxes * 5). Each slot holds
    (x % length, y % length, w, h, c), where c is:

    - 1 for an original box;
    - -1 for a sampled negative;
    - 0 for an empty slot.

    A box goes into the first empty slot of the cell containing its top-left
    corner. When that cell is full, an original box overwrites the last slot.
    A negative is dropped instead, so it never replaces a real label.
    '''
    elements_box = 5
    slot_span = boxes * elements_box

    patches = []
    for img in label:
        patch = np.zeros(cells * cells * slot_span, dtype='float')
        good_index = len(img)  # entries at index >= good_index are negatives
        if is_addN:
            addNegtiveYOLT(img, 0.2)
            if new_boxes is not None:
                new_boxes.append(img)
        for i, box in enumerate(img):
            positive = i < good_index
            cell = int(box[0] / length) + int(int(box[1] / length) * cells)
            base = cell * slot_span

            # First free slot (confidence 0) of this cell.
            offset = None
            for slot in range(boxes):
                start = base + slot * elements_box
                if patch[start + 4] == 0:
                    offset = start
                    break
            if offset is None:
                if not positive:
                    continue  # cell full: never overwrite a label with a negative
                offset = base + (boxes - 1) * elements_box

            patch[offset:offset + elements_box] = (
                box[0] % length,
                box[1] % length,
                box[2],
                box[3],
                1 if positive else -1,
            )
        patches.append(patch.reshape((cells, cells, slot_span)))
    return patches

def cutImage(image, label, labelFun, name, scale=416, stride=0.3):
    '''
    Description
    -----------
    Tile ``image`` into scale x scale crops with a step of scale * stride
    pixels, and derive each crop's labels with ``labelFun``.

    Notes
    -----
    - An image smaller than ``scale`` in either dimension is padded with black
      (zeros, same dtype as the input) on the bottom/right.
    - The last crop along each axis is shifted back so that it ends exactly at
      the image border.
    - Crops are numpy slices, i.e. views into the (possibly padded) image.

    Parameters
    ----------
    - image: input image of shape (height, width, channels)
    - label: labels, passed unchanged to ``labelFun``
    - labelFun: called as labelFun(label, start_x, start_y, end_x, end_y) per crop
    - name: printed when truthy, to identify the image
    - scale: side length of each crop
    - stride: crop step as a fraction of ``scale``

    Returns
    -------
    (patches_img, patches_label). patches_img maps (start_x, start_y) to the
    crop. patches_label is a list of labelFun results in the same (row-major)
    order, and has one entry per patches_img key.
    '''
    if name:
        print(name)

    height, width = image.shape[:-1]
    step = scale * stride

    if height < scale or width < scale:
        padded = np.zeros((max(height, scale), max(width, scale)) + image.shape[2:],
                          dtype=image.dtype)
        padded[:height, :width] = image
        image = padded
        height, width = image.shape[:-1]

    patches_img = {}
    patches_label = []
    x_starts = _crop_starts(width, scale, step)
    for start_y in _crop_starts(height, scale, step):
        end_y = start_y + scale
        for start_x in x_starts:
            end_x = start_x + scale
            patches_img[(start_x, start_y)] = image[start_y:end_y, start_x:end_x]
            patches_label.append(labelFun(label, start_x, start_y, end_x, end_y))
    return patches_img, patches_label


def _crop_starts(extent, scale, step):
    """Return the ordered, de-duplicated crop offsets along an axis of length ``extent``."""
    count = math.ceil((extent - scale) / step) + 1
    # Offsets past the last full window are clamped to (extent - scale).
    # De-duplicate them so that float rounding in ``step`` can never yield two
    # equal offsets. Equal offsets would collide as dict keys in cutImage while
    # still adding an entry to the label list.
    return list(dict.fromkeys(int(min(i * step, extent - scale)) for i in range(count)))

if __name__ == '__main__':
    # Interactive visual check of the crop + label-conversion pipeline: every
    # crop is shown with its original boxes in green and sampled negatives in
    # blue (OpenCV colours are BGR).
    label_file = "/home/jerry/dataSet/faces/wider_face_split/wider_face_train_bbx_gt.txt"
    root_dir = "/home/jerry/dataSet/faces/WIDER_train/images/"

    labels = lls(label_file, root_dir, 5555, 5555)
    while labels.load():
        image_file_name = labels.getFileName()
        image = cv2.imread(image_file_name)
        if image is None:
            # cv2.imread returns None for missing or unreadable files.
            print('failed to read', image_file_name)
            continue

        times = 1

        label = labels.getBoundingBoxesListX(times)
        image = cv2.resize(image, (int(image.shape[1] * times), int(image.shape[0] * times)))
        cv2.imshow('image', image)
        cv2.waitKey(10)
        patches_img, patches_label = cutImage(image, label, labelFunV1, str(times) + ': ' + image_file_name)
        # Number of real boxes per patch, captured before converLabelYOLT
        # appends negatives to these same lists in place.
        good_index = [len(patch_boxes) for patch_boxes in patches_label]
        new_boxes = []
        new_label = converLabelYOLT(patches_label, new_boxes=new_boxes)
        for i, crop in enumerate(patches_img.values()):
            # Crops are slices (views) of one shared image; draw on a copy so
            # rectangles from one crop do not bleed into overlapping crops.
            canvas = crop.copy()
            for j, box in enumerate(new_boxes[i]):
                colour = (0, 255, 0) if j < good_index[i] else (255, 0, 0)
                cv2.rectangle(canvas, (int(box[0]), int(box[1])),
                              (int(box[0] + box[2]), int(box[1] + box[3])), colour)
            cv2.imshow('cut', canvas)
            cv2.waitKey(0)